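'''
Plotting helpers: read CSV result tables, draw one matplotlib errorbar curve per
row, and keep a persistent label -> fmt-string dictionary (saved to an .npz file)
so each label is drawn with the same style across figures.
'''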
import numpy as np
import scipy.sparse as sp
from numpy import genfromtxt

import matplotlib

#matplotlib.rcParams['ps.useafm'] = True
#matplotlib.rcParams['pdf.use14corefonts'] = True
#matplotlib.rcParams['text.usetex'] = True

def SetPlotRC():
    '''Make PDF and PostScript output embed TrueType (Type 42) fonts.

    This changes pyplot's global rc parameters, so it affects every figure
    saved afterwards, not just the current one.
    '''
    # The same setting applies to both vector backends; pdf is configured first.
    for backend in ('pdf', 'ps'):
        plt.rc(backend, fonttype=42)

def ApplyFont(ax):
    '''Restyle the tick labels, axis labels and title of `ax` as 16 pt Times New Roman.

    The axis labels and the title are re-set with their current text. This
    yields the Text objects that are then restyled.
    '''
    fontname = 'Times New Roman'
    fontsize = 16.0

    def _restyle(text_obj):
        # Font family first, then size, matching the order used for every element.
        text_obj.set_fontname(fontname)
        text_obj.set_fontsize(fontsize)

    for tick in ax.get_xticklabels() + ax.get_yticklabels():
        _restyle(tick)

    # x label, then y label, then title.
    for getter, setter in ((ax.get_xlabel, ax.set_xlabel),
                           (ax.get_ylabel, ax.set_ylabel),
                           (ax.get_title, ax.set_title)):
        _restyle(setter(getter()))


# Select the non-interactive Agg backend. This must happen before pyplot is
# imported on the next line; figures are only ever written to files (plotter
# calls plt.savefig).
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Building blocks for matplotlib fmt strings. plotter() concatenates one entry
# from each list: color code (cola) + marker code (colb) + line style (colc).
cola = ['r','b','g', 'k', 'c', 'm', 'y']
colb = ['o', 'x', '<', '>', 'v', '^', '+', 's', 'd', 'p', 'h'] #, '*'
colc = ['-', ':', '--']

def _pick_styles(count, taken, colors, markers, linestyles):
    '''Return `count` new matplotlib fmt strings that do not appear in `taken`.

    Each candidate is color + marker + line style, drawn from the three given
    lists. Candidates are visited in a random order (numpy's global RNG).

    Selection happens in two passes:
      1. Prefer candidates whose color+marker prefix is not used yet, so that
         curves stay visually distinct.
      2. Accept any unused combination.

    The prefix test assumes single-character color and marker codes, as in
    `cola`/`colb`.

    Raises ValueError if fewer than `count` unused combinations exist. The old
    fill strategy could loop forever in that case.
    '''
    taken = set(taken)
    candidates = [c + m + s for c in colors for m in markers for s in linestyles]
    order = [candidates[i] for i in np.random.permutation(len(candidates))]
    used_prefixes = {str(fmt)[:2] for fmt in taken}
    chosen = []
    for fmt in order:  # pass 1: at most one style per fresh color/marker pair
        if len(chosen) == count:
            break
        if fmt not in taken and fmt[:2] not in used_prefixes:
            chosen.append(fmt)
            taken.add(fmt)
            used_prefixes.add(fmt[:2])
    for fmt in order:  # pass 2: any combination that is still unused
        if len(chosen) == count:
            break
        if fmt not in taken:
            chosen.append(fmt)
            taken.add(fmt)
    if len(chosen) < count:
        raise ValueError("cannot assign %d distinct plot styles: only %d unused "
                         "combinations left" % (count, len(chosen)))
    return chosen


def plotter(Y, X=None, errors=None, labels=None, cdict=None, xlabel=None,
            ylabel=None, filename="myplot", blocked=None, suffix=".pdf",
            num_plot=None, xticklabels=None, xtickloc=None, ylim=None, iloc="best"):
    '''Draw one errorbar curve per row of `Y` and save the figure to `filename`.

    Parameters
    ----------
    Y : ndarray, shape (M, N)
        Plot data, one row per curve (M curves, N points each).
    X : array-like, optional
        x-coordinates. Either shape (N,), shared by all curves, or (M, N), one
        row per curve. Defaults to ``np.arange(N)``.
    errors : indexable by row, optional
        ``errors[r]`` is passed as ``yerr`` for curve ``r``. When None, no
        error bars are drawn.
    labels : sequence of length M, optional
        Legend labels and keys into `cdict`. When None, ``0..M-1`` are used as
        keys and no legend is drawn.
    cdict : dict, optional
        Maps label -> matplotlib fmt string. Labels missing from it are filled
        IN PLACE with random unused styles. When None, a fresh dict is built.
    xlabel, ylabel : str, optional
        Axis label texts.
    filename : str
        Output path. `suffix` is appended if the path does not already end with it.
    blocked : container of int, optional
        Row indices to skip. Skipped rows are also left out of the legend.
    num_plot : int, optional
        Approximate number of markers per curve (markevery = N // num_plot,
        at least 1). Defaults to N, i.e. a marker at every point.
    xticklabels : sequence, optional
        Custom x tick labels.
    xtickloc : sequence, optional
        Locations for `xticklabels`. Defaults to the current tick locations.
    ylim : (low, high), optional
        y-axis limits.
    iloc : str
        Legend location passed to ``plt.legend``.

    Returns
    -------
    dict
        The label -> fmt dictionary used, so later plots can reuse the same styles.
    '''
    SetPlotRC()
    if blocked is None:  # avoid a shared mutable default argument
        blocked = set()

    if not filename.endswith(suffix):
        filename += suffix

    M, N = Y.shape

    lnone = labels is None
    if lnone:
        labels = np.arange(M)

    if cdict is None:
        cdict = dict(zip(labels, _pick_styles(M, (), cola, colb, colc)))
    else:
        # dict.fromkeys removes repeated labels while keeping their order.
        missing = list(dict.fromkeys(l for l in labels if l not in cdict))
        if missing:
            print("partial dict detected, filling %d missing entries" % len(missing))
            new_styles = _pick_styles(len(missing), cdict.values(), cola, colb, colc)
            cdict.update(zip(missing, new_styles))

    if X is None:
        X = np.arange(N)
    else:
        X = np.asarray(X)  # also accept plain lists
        if X.ndim == 1:
            assert len(X) == N, "Insufficient number of x:%d (N=%d)" % (len(X), N)
        else:
            assert X.shape[0] == M, "Insufficient number of x:%d (M=%d)" % (X.shape[0], M)
            assert X.shape[1] == N, "Insufficient number of x:%d (N=%d)" % (X.shape[1], N)

    if num_plot is None:
        num_plot = N
    # markevery must be a positive step. Guard against num_plot > N and N == 0.
    markevery = max(1, N // num_plot) if num_plot else 1

    for r in range(M):
        if r in blocked:
            continue
        x = X[r] if X.ndim == 2 else X
        yerr = None if errors is None else errors[r]
        plt.errorbar(x, Y[r], fmt=cdict[labels[r]], linewidth=3.0,
                     markevery=markevery, markersize=8.0, yerr=yerr)

    if xticklabels is not None:
        if xtickloc is None:
            current_xticklocations, _ = plt.xticks()
        else:
            current_xticklocations = xtickloc
        plt.xticks(current_xticklocations, xticklabels)

    if not lnone:
        plt.legend([lab for k, lab in enumerate(labels) if k not in blocked], loc=iloc)
    if xlabel is not None:
        plt.xlabel(xlabel, fontsize=9)
    if ylabel is not None:
        plt.ylabel(ylabel, fontsize=9)

    if ylim is not None:
        plt.ylim(ylim[0], ylim[1])

    plt.tight_layout()
    ApplyFont(plt.gca())
    plt.savefig(filename)
    plt.clf()
    plt.close("all")
    return cdict
#
# def main_try():
#     M, N = 3, 20
#     labels = ["row"+`k` for k in range(M) ]
#     data   = np.random.randn(M, N)
#     #filename = "table.txt"
#     filename="/home/rajiv/Dropbox/code/liclipseWorkspace/sparsePriorDesign/abc"
#     plotter(data, labels=labels, filename=filename,xlabel='X', ylabel='Y')

def get_plotdata(filename):
    '''Read a row-oriented results table in CSV format.

    Layout: the first row holds the x tick labels (its first cell is ignored).
    Each following row is a curve label followed by that curve's numeric values.
    `filename` may be anything numpy.genfromtxt accepts (path, file, list of lines).

    Returns (to_plot, xticklabels, labels):
      to_plot     -- float array, one row per curve
      xticklabels -- header cells after the first
      labels      -- first cell of each data row
    '''
    dat = genfromtxt(filename, delimiter=',', dtype=str)
    xticklabels = dat[0, 1:]
    labels = dat[1:, 0]
    # np.float was removed in NumPy 1.24. The builtin float maps to the same float64 dtype.
    to_plot = dat[1:, 1:].astype(float)
    return to_plot, xticklabels, labels

def get_plotdata2(filename):
    '''Read a column-oriented results table in CSV format.

    Layout: the first row holds the labels (its first cell is ignored). Each
    following row starts with an x tick label followed by numeric values.
    `filename` may be anything numpy.genfromtxt accepts (path, file, list of lines).

    Returns (to_plot, xticklabels, labels):
      to_plot     -- float array; rows follow xticklabels, columns follow labels
      xticklabels -- first cell of each data row
      labels      -- header cells after the first

    NOTE(review): unlike get_plotdata, to_plot is NOT transposed. plotter()
    expects one row per curve, so callers presumably need ``to_plot.T``.
    Confirm against the callers.
    '''
    dat = genfromtxt(filename, delimiter=',', dtype=str)
    xticklabels = dat[1:, 0]
    labels = dat[0, 1:]
    # np.float was removed in NumPy 1.24. The builtin float maps to the same float64 dtype.
    to_plot = dat[1:, 1:].astype(float)
    return to_plot, xticklabels, labels

def main_plot_toy(prefix="Downloads/"):
    '''Plot the toy-experiment results with error bars to ``<prefix>res.pdf``.

    Steps:
      1. Load the saved label -> style dictionary from ``<prefix>newdict2.npz``.
      2. Read ``results.csv`` and ``errors.csv`` from `prefix` (layout as in
         get_plotdata) and plot them.
      3. Write the possibly extended dictionary back to ``newdict2.npz``, so
         later figures reuse the same styles.
    '''
    # allow_pickle is required because cdict is stored as a pickled object
    # array. Only load .npz files written by this script itself.
    with np.load(prefix + 'newdict2.npz', allow_pickle=True) as dictload:
        cdict = dictload['cdict'].item()

    data, xticklabels, labels = get_plotdata(prefix + 'results.csv')
    errors, _, _ = get_plotdata(prefix + 'errors.csv')

    # Raw string: "\m" is not a valid escape. The value is unchanged.
    cdict = plotter(data, errors=errors, labels=labels, filename=prefix + 'res.pdf',
                    cdict=cdict, ylabel=r"$\mu_y(X)$", xlabel="Number of sketched rows",
                    xtickloc=[0, 1, 2, 3, 4], xticklabels=xticklabels, ylim=[-2, 14.0])
    np.savez(prefix + 'newdict2.npz', cdict=cdict)
    
    
#     # noise variances
#     filename = prefix + 'diff.n.r2.csv'
#     data, xticklabels, labels = get_plotdata(filename)
# 
#     cdict=plotter(data, cdict=cdict, labels=labels, filename='diff.n.r2.pdf', xlabel="n:k", ylabel="$R^2$", xticklabels=xticklabels, xtickloc=[0,1,2,3], ylim=[-0.1,1.1])
# 
#     filename = prefix + 'diff.n.suppauc.csv'
#     data, xticklabels, labels = get_plotdata(filename)

#    plotter(data, cdict=cdict, labels=labels, filename='diff.n.suppauc.pdf', xlabel="n:k", ylabel="Support AUC",  xtickloc=[0,1,2,3], xticklabels=xticklabels, ylim=[0.45,1.05])

def main_plot_real(prefix="Downloads/"):
    '''Plot the real-data results with error bars to ``<prefix>resreal.pdf``.

    Steps:
      1. Load the saved label -> style dictionary from ``<prefix>newdict2.npz``.
      2. Read ``resultsreal.csv`` and ``errorsreal.csv`` from `prefix` (layout
         as in get_plotdata) and plot them.
      3. Write the possibly extended dictionary back to ``newdict2.npz``, so
         later figures reuse the same styles.
    '''
    # allow_pickle is required because cdict is stored as a pickled object
    # array. Only load .npz files written by this script itself.
    with np.load(prefix + 'newdict2.npz', allow_pickle=True) as dictload:
        cdict = dictload['cdict'].item()

    data, xticklabels, labels = get_plotdata(prefix + 'resultsreal.csv')
    errors, _, _ = get_plotdata(prefix + 'errorsreal.csv')

    # Raw string: "\m" is not a valid escape. The value is unchanged.
    cdict = plotter(data, errors=errors, labels=labels, filename=prefix + 'resreal.pdf',
                    cdict=cdict, ylabel=r"$\mu_y(X)$", xlabel="Number of sketched rows",
                    xtickloc=[0, 1, 2, 3], xticklabels=xticklabels, ylim=[-2, 12500])
    np.savez(prefix + 'newdict2.npz', cdict=cdict)


if __name__ == '__main__':
    # Regenerate the real-data figure. Call main_plot_toy() instead for the toy data.
    main_plot_real()
